{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "markdown"
    }
   },
   "source": [
    "# Relevant Segment Extraction (RSE) for Enhanced RAG\n",
    "\n",
    "In this notebook, we implement a Relevant Segment Extraction (RSE) technique to improve the context quality in our RAG system. Rather than simply retrieving a collection of isolated chunks, we identify and reconstruct continuous segments of text that provide better context to our language model.\n",
    "\n",
    "## Key Concept\n",
    "\n",
    "Relevant chunks tend to be clustered together within documents. By identifying these clusters and preserving their continuity, we provide more coherent context for the LLM to work with."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting Up the Environment\n",
    "We begin by importing the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fitz\n",
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "from openai import OpenAI\n",
    "import re"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extracting Text from a PDF File\n",
    "To implement RAG, we first need a source of textual data. In this case, we extract text from a PDF file using the PyMuPDF library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_text_from_pdf(pdf_path):\n",
    "    \"\"\"\n",
    "    Extracts the text of every page in a PDF file.\n",
    "\n",
    "    Args:\n",
    "    pdf_path (str): Path to the PDF file.\n",
    "\n",
    "    Returns:\n",
    "    str: Extracted text from the PDF, with pages concatenated in order.\n",
    "    \"\"\"\n",
    "    all_text = \"\"  # Initialize an empty string to store the extracted text\n",
    "\n",
    "    # Open the PDF file; the context manager closes it when we are done\n",
    "    with fitz.open(pdf_path) as mypdf:\n",
    "        # Iterate through each page in the PDF\n",
    "        for page in mypdf:\n",
    "            all_text += page.get_text(\"text\")  # Append the page's text to all_text\n",
    "\n",
    "    return all_text  # Return the extracted text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chunking the Extracted Text\n",
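    "Once we have the extracted text, we divide it into fixed-size, non-overlapping chunks. RSE needs them to be non-overlapping: consecutive chunks tile the text exactly, so a run of adjacent chunks can later be stitched back into one continuous segment."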
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunk_text(text, chunk_size=800, overlap=0):\n",
    "    \"\"\"\n",
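    "    Split text into fixed-size character chunks (non-overlapping by default).\n",
    "    For RSE, we want non-overlapping chunks (overlap=0) so that adjacent chunks can be joined back into the original text.\n",
    "    \n",
    "    Args:\n",
    "        text (str): Input text to chunk\n",
    "        chunk_size (int): Size of each chunk in characters\n",
    "        overlap (int): Overlap between consecutive chunks in characters; must be smaller than chunk_size\n",
    "        \n",
    "    Returns:\n",
    "        List[str]: List of text chunks\n",
    "    \"\"\"\n",
    "    if overlap >= chunk_size:\n",
    "        raise ValueError(\"overlap must be smaller than chunk_size\")\n",
    "    \n",
    "    chunks = []\n",
    "    \n",
    "    # Simple character-based chunking: each window starts chunk_size - overlap characters after the previous one\n",
    "    for i in range(0, len(text), chunk_size - overlap):\n",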
    "        chunk = text[i:i + chunk_size]\n",
    "        if chunk:  # Ensure we don't add empty chunks\n",
    "            chunks.append(chunk)\n",
    "    \n",
    "    return chunks"
   ]
  },
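  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the chunks are fixed-size and non-overlapping, a chunk index maps directly to a character range: chunk `i` covers `text[i * chunk_size:(i + 1) * chunk_size]`. The sketch below checks this on a toy string. `split_fixed` repeats the slicing that `chunk_text` performs with `overlap=0`, and `segment_char_span` is a hypothetical helper introduced only for this illustration. Joining a run of adjacent chunks with `\"\"` reproduces the matching slice of the original text exactly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_fixed(text, chunk_size):\n",
    "    # Same slicing as chunk_text(text, chunk_size, overlap=0): consecutive, non-overlapping windows\n",
    "    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]\n",
    "\n",
    "\n",
    "def segment_char_span(start, end, chunk_size, text_length):\n",
    "    # Hypothetical helper: character range covered by chunks start..end-1 (half-open, like a slice)\n",
    "    return start * chunk_size, min(end * chunk_size, text_length)\n",
    "\n",
    "\n",
    "sample_text = \"Relevant chunks tend to be clustered together within documents.\"\n",
    "sample_chunks = split_fixed(sample_text, 10)\n",
    "\n",
    "# A run of adjacent chunks, joined with \"\", equals the corresponding slice of the text\n",
    "char_start, char_end = segment_char_span(2, 5, 10, len(sample_text))\n",
    "assert \"\".join(sample_chunks[2:5]) == sample_text[char_start:char_end]\n",
    "assert \"\".join(sample_chunks) == sample_text\n",
    "print(len(sample_chunks), repr(\"\".join(sample_chunks[2:5])))"
   ]
  },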
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting Up the OpenAI API Client\n",
    "We initialize the OpenAI client to generate embeddings and responses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the OpenAI client with the base URL and API key\n",
    "client = OpenAI(\n",
    "    base_url=\"https://api.studio.nebius.com/v1/\",\n",
    "    api_key=os.getenv(\"OPENAI_API_KEY\")  # Retrieve the API key from environment variables\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a Simple Vector Store\n",
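    "Let's implement a simple in-memory vector store that keeps each chunk's text, embedding, and metadata, and ranks chunks by cosine similarity to a query."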
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleVectorStore:\n",
    "    \"\"\"\n",
    "    A lightweight vector store implementation using NumPy.\n",
    "    \"\"\"\n",
    "    def __init__(self, dimension=1536):\n",
    "        \"\"\"\n",
    "        Initialize the vector store.\n",
    "        \n",
    "        Args:\n",
    "            dimension (int): Dimension of embeddings\n",
    "        \"\"\"\n",
    "        self.dimension = dimension\n",
    "        self.vectors = []\n",
    "        self.documents = []\n",
    "        self.metadata = []\n",
    "    \n",
    "    def add_documents(self, documents, vectors=None, metadata=None):\n",
    "        \"\"\"\n",
    "        Add documents to the vector store.\n",
    "        \n",
    "        Args:\n",
    "            documents (List[str]): List of document chunks\n",
    "            vectors (List[List[float]], optional): List of embedding vectors\n",
    "            metadata (List[Dict], optional): List of metadata dictionaries\n",
    "        \"\"\"\n",
    "        if vectors is None:\n",
    "            vectors = [None] * len(documents)\n",
    "        \n",
    "        if metadata is None:\n",
    "            metadata = [{} for _ in range(len(documents))]\n",
    "        \n",
    "        for doc, vec, meta in zip(documents, vectors, metadata):\n",
    "            self.documents.append(doc)\n",
    "            self.vectors.append(vec)\n",
    "            self.metadata.append(meta)\n",
    "    \n",
    "    def search(self, query_vector, top_k=5):\n",
    "        \"\"\"\n",
    "        Search for most similar documents.\n",
    "        \n",
    "        Args:\n",
    "            query_vector (List[float]): Query embedding vector\n",
    "            top_k (int): Number of results to return\n",
    "            \n",
    "        Returns:\n",
    "            List[Dict]: List of results with documents, scores, and metadata\n",
    "        \"\"\"\n",
    "        if not self.vectors or not self.documents:\n",
    "            return []\n",
    "        \n",
    "        # Convert query vector to numpy array\n",
    "        query_array = np.array(query_vector)\n",
    "        \n",
    "        # Calculate similarities\n",
    "        similarities = []\n",
    "        for i, vector in enumerate(self.vectors):\n",
    "            if vector is not None:\n",
    "                # Compute cosine similarity\n",
    "                similarity = np.dot(query_array, vector) / (\n",
    "                    np.linalg.norm(query_array) * np.linalg.norm(vector)\n",
    "                )\n",
    "                similarities.append((i, similarity))\n",
    "        \n",
    "        # Sort by similarity (descending)\n",
    "        similarities.sort(key=lambda x: x[1], reverse=True)\n",
    "        \n",
    "        # Get top-k results\n",
    "        results = []\n",
    "        for i, score in similarities[:top_k]:\n",
    "            results.append({\n",
    "                \"document\": self.documents[i],\n",
    "                \"score\": float(score),\n",
    "                \"metadata\": self.metadata[i]\n",
    "            })\n",
    "        \n",
    "        return results"
   ]
  },
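  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SimpleVectorStore.search` computes cosine similarity one vector at a time in a Python loop, which is fine for a few dozen chunks. As an alternative sketch (not what the class above does), the same ranking can be computed with a single matrix-vector product in NumPy: stack the embeddings into a matrix, divide the dot products by both norms, and sort. `cosine_top_k` and the toy vectors are hypothetical, for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def cosine_top_k(query_vector, vectors, top_k=5):\n",
    "    \"\"\"Return (index, cosine similarity) pairs for the top_k rows most similar to the query.\"\"\"\n",
    "    matrix = np.asarray(vectors, dtype=float)      # shape: (num_chunks, dim)\n",
    "    query = np.asarray(query_vector, dtype=float)  # shape: (dim,)\n",
    "    # Dot products divided by both norms give the cosine similarity of every row at once\n",
    "    sims = matrix @ query / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(query))\n",
    "    # Sort by descending similarity and keep the first top_k indices\n",
    "    order = np.argsort(-sims)[:top_k]\n",
    "    return [(int(i), float(sims[i])) for i in order]\n",
    "\n",
    "\n",
    "toy_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]\n",
    "print(cosine_top_k([1.0, 0.2], toy_vectors, top_k=2))"
   ]
  },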
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Embeddings for Text Chunks\n",
    "Embeddings transform text into numerical vectors, which allow for efficient similarity search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_embeddings(texts, model=\"BAAI/bge-en-icl\"):\n",
    "    \"\"\"\n",
    "    Generate embeddings for texts.\n",
    "    \n",
    "    Args:\n",
    "        texts (List[str]): List of texts to embed\n",
    "        model (str): Embedding model to use\n",
    "        \n",
    "    Returns:\n",
    "        List[List[float]]: List of embedding vectors\n",
    "    \"\"\"\n",
    "    if not texts:\n",
    "        return []  # Return an empty list if no texts are provided\n",
    "        \n",
    "    # Process in batches if the list is long\n",
    "    batch_size = 100  # Adjust based on your API limits\n",
    "    all_embeddings = []  # Initialize a list to store all embeddings\n",
    "    \n",
    "    for i in range(0, len(texts), batch_size):\n",
    "        batch = texts[i:i + batch_size]  # Get the current batch of texts\n",
    "        \n",
    "        # Create embeddings for the current batch using the specified model\n",
    "        response = client.embeddings.create(\n",
    "            input=batch,\n",
    "            model=model\n",
    "        )\n",
    "        \n",
    "        # Extract embeddings from the response\n",
    "        batch_embeddings = [item.embedding for item in response.data]\n",
    "        all_embeddings.extend(batch_embeddings)  # Add the batch embeddings to the list\n",
    "        \n",
    "    return all_embeddings  # Return the list of all embeddings"
   ]
  },
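  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every call to `create_embeddings` goes through the API. For exercising the segment-selection code offline, a deterministic stand-in with the same `List[List[float]]` output can be convenient. The sketch below is a hypothetical hashed bag-of-words embedder (`fake_embeddings` is our own name, not part of any library). It only reflects shared words, not meaning, so it is no substitute for `BAAI/bge-en-icl` when judging retrieval quality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hashlib\n",
    "import re\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def fake_embeddings(texts, dimension=256):\n",
    "    \"\"\"Hypothetical offline stand-in for create_embeddings: hashed bag-of-words vectors.\"\"\"\n",
    "    embeddings = []\n",
    "    for text in texts:\n",
    "        vec = np.zeros(dimension)\n",
    "        for word in re.findall(r\"[a-z0-9]+\", text.lower()):\n",
    "            # md5 maps a word to the same bucket in every run (Python's built-in hash() is salted per process)\n",
    "            bucket = int(hashlib.md5(word.encode(\"utf-8\")).hexdigest(), 16) % dimension\n",
    "            vec[bucket] += 1.0\n",
    "        norm = np.linalg.norm(vec)\n",
    "        # Unit-normalize so dot products are cosine similarities; empty texts stay all-zero\n",
    "        embeddings.append((vec / norm if norm > 0 else vec).tolist())\n",
    "    return embeddings\n",
    "\n",
    "\n",
    "emb_a, emb_b, emb_c = fake_embeddings([\n",
    "    \"explainable AI builds trust\",\n",
    "    \"trust in explainable AI\",\n",
    "    \"weather forecast for tomorrow\",\n",
    "])\n",
    "print(round(float(np.dot(emb_a, emb_b)), 3), round(float(np.dot(emb_a, emb_c)), 3))"
   ]
  },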
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Processing Documents with RSE\n",
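    "The first stage of RSE prepares the document: it extracts the text, splits it into non-overlapping chunks, embeds each chunk, and records each chunk's index so that neighboring chunks can be stitched back together later."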
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_document(pdf_path, chunk_size=800):\n",
    "    \"\"\"\n",
    "    Process a document for use with RSE.\n",
    "    \n",
    "    Args:\n",
    "        pdf_path (str): Path to the PDF document\n",
    "        chunk_size (int): Size of each chunk in characters\n",
    "        \n",
    "    Returns:\n",
    "        Tuple[List[str], SimpleVectorStore, Dict]: Chunks, vector store, and document info\n",
    "    \"\"\"\n",
    "    print(\"Extracting text from document...\")\n",
    "    # Extract text from the PDF file\n",
    "    text = extract_text_from_pdf(pdf_path)\n",
    "    \n",
    "    print(\"Chunking text into non-overlapping segments...\")\n",
    "    # Chunk the extracted text into non-overlapping segments\n",
    "    chunks = chunk_text(text, chunk_size=chunk_size, overlap=0)\n",
    "    print(f\"Created {len(chunks)} chunks\")\n",
    "    \n",
    "    print(\"Generating embeddings for chunks...\")\n",
    "    # Generate embeddings for the text chunks\n",
    "    chunk_embeddings = create_embeddings(chunks)\n",
    "    \n",
    "    # Create an instance of the SimpleVectorStore\n",
    "    vector_store = SimpleVectorStore()\n",
    "    \n",
    "    # Add documents with metadata (including chunk index for later reconstruction)\n",
    "    metadata = [{\"chunk_index\": i, \"source\": pdf_path} for i in range(len(chunks))]\n",
    "    vector_store.add_documents(chunks, chunk_embeddings, metadata)\n",
    "    \n",
    "    # Track original document structure for segment reconstruction\n",
    "    doc_info = {\n",
    "        \"chunks\": chunks,\n",
    "        \"source\": pdf_path,\n",
    "    }\n",
    "    \n",
    "    return chunks, vector_store, doc_info"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RSE Core Algorithm: Computing Chunk Values and Finding Best Segments\n",
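    "With the document processed and its chunks embedded, we can implement the core RSE algorithm. Each chunk receives a value equal to its similarity to the query minus a fixed penalty, so chunks whose similarity falls below the penalty get negative values. We then search for contiguous runs of chunks with the highest summed value, subject to limits on the length of each segment and on the total number of chunks selected."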
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty=0.2):\n",
    "    \"\"\"\n",
    "    Calculate chunk values as relevance score minus a constant penalty.\n",
    "    \n",
    "    Args:\n",
    "        query (str): Query text\n",
    "        chunks (List[str]): List of document chunks\n",
    "        vector_store (SimpleVectorStore): Vector store containing the chunks\n",
    "        irrelevant_chunk_penalty (float): Penalty for irrelevant chunks\n",
    "        \n",
    "    Returns:\n",
    "        List[float]: List of chunk values\n",
    "    \"\"\"\n",
    "    # Create query embedding\n",
    "    query_embedding = create_embeddings([query])[0]\n",
    "    \n",
    "    # Get all chunks with similarity scores\n",
    "    num_chunks = len(chunks)\n",
    "    results = vector_store.search(query_embedding, top_k=num_chunks)\n",
    "    \n",
    "    # Create a mapping of chunk_index to relevance score\n",
    "    relevance_scores = {result[\"metadata\"][\"chunk_index\"]: result[\"score\"] for result in results}\n",
    "    \n",
    "    # Calculate chunk values (relevance score minus penalty)\n",
    "    chunk_values = []\n",
    "    for i in range(num_chunks):\n",
    "        # Get relevance score or default to 0 if not in results\n",
    "        score = relevance_scores.get(i, 0.0)\n",
    "        # Apply penalty to convert to a value where irrelevant chunks have negative value\n",
    "        value = score - irrelevant_chunk_penalty\n",
    "        chunk_values.append(value)\n",
    "    \n",
    "    return chunk_values"
   ]
  },
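  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the penalty concrete: with the default `irrelevant_chunk_penalty=0.2`, a chunk with cosine similarity 0.45 gets value 0.25, while one at 0.10 gets -0.10. The penalty therefore acts as a relevance threshold, and a segment can absorb a weakly relevant chunk when its neighbors are relevant enough to outweigh it. The similarity scores below are made-up numbers for illustration, not results from the document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up similarity scores for six consecutive chunks (illustration only)\n",
    "example_similarities = np.array([0.10, 0.45, 0.55, 0.15, 0.40, 0.05])\n",
    "penalty = 0.2\n",
    "\n",
    "# Same rule as calculate_chunk_values: value = similarity - penalty\n",
    "example_values = example_similarities - penalty\n",
    "print(np.round(example_values, 2))\n",
    "\n",
    "# Only chunks above the penalty contribute positively\n",
    "print(np.flatnonzero(example_values > 0))\n",
    "\n",
    "# Bridging the slightly negative chunk 3 still pays off: (1, 5) beats (1, 3)\n",
    "print(example_values[1:5].sum(), example_values[1:3].sum())"
   ]
  },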
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_best_segments(chunk_values, max_segment_length=20, total_max_length=30, min_segment_value=0.2):\n",
    "    \"\"\"\n",
    "    Find the best segments using a greedy variant of the maximum sum subarray algorithm.\n",
    "    \n",
    "    Args:\n",
    "        chunk_values (List[float]): Values for each chunk\n",
    "        max_segment_length (int): Maximum length of a single segment\n",
    "        total_max_length (int): Maximum total length across all segments\n",
    "        min_segment_value (float): Minimum value for a segment to be considered\n",
    "        \n",
    "    Returns:\n",
    "        Tuple[List[Tuple[int, int]], List[float]]: Half-open (start, end) indices of the best\n",
    "        segments, sorted by start position, and the score of each segment\n",
    "    \"\"\"\n",
    "    print(\"Finding optimal continuous text segments...\")\n",
    "    \n",
    "    best_segments = []\n",
    "    segment_scores = []\n",
    "    total_included_chunks = 0\n",
    "    \n",
    "    # Keep finding segments until we hit our limits\n",
    "    while total_included_chunks < total_max_length:\n",
    "        best_score = min_segment_value  # Minimum threshold for a segment\n",
    "        best_segment = None\n",
    "        # A new segment may not push the total past total_max_length\n",
    "        length_budget = min(max_segment_length, total_max_length - total_included_chunks)\n",
    "        \n",
    "        # Try each possible starting position\n",
    "        for start in range(len(chunk_values)):\n",
    "            # Try each possible segment length\n",
    "            for length in range(1, min(length_budget, len(chunk_values) - start) + 1):\n",
    "                end = start + length\n",
    "                \n",
    "                # Stop if this candidate overlaps an already selected segment;\n",
    "                # every longer candidate from the same start would overlap too\n",
    "                if any(start < s[1] and end > s[0] for s in best_segments):\n",
    "                    break\n",
    "                \n",
    "                # Calculate segment value as sum of chunk values\n",
    "                segment_value = sum(chunk_values[start:end])\n",
    "                \n",
    "                # Update best segment if this one is better\n",
    "                if segment_value > best_score:\n",
    "                    best_score = segment_value\n",
    "                    best_segment = (start, end)\n",
    "        \n",
    "        # If we found a good segment, add it\n",
    "        if best_segment is not None:\n",
    "            best_segments.append(best_segment)\n",
    "            segment_scores.append(best_score)\n",
    "            total_included_chunks += best_segment[1] - best_segment[0]\n",
    "            print(f\"Found segment {best_segment} with score {best_score:.4f}\")\n",
    "        else:\n",
    "            # No more good segments to find\n",
    "            break\n",
    "    \n",
    "    # Sort segments (and their scores) by starting position for readability\n",
    "    order = sorted(range(len(best_segments)), key=lambda i: best_segments[i][0])\n",
    "    best_segments = [best_segments[i] for i in order]\n",
    "    segment_scores = [segment_scores[i] for i in order]\n",
    "    \n",
    "    return best_segments, segment_scores"
   ]
  },
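  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`find_best_segments` recomputes `sum(chunk_values[start:end])` for every candidate, so one pass over n chunks with maximum segment length L costs roughly O(n·L²) additions. A sketch of the same greedy search using prefix sums (our own variant, named `find_best_segments_prefix` for illustration) gets each candidate sum in constant time, bringing a pass down to roughly O(n·L). It enforces these constraints: segments may not overlap, each is at most `max_segment_length` chunks, the total never exceeds `total_max_length`, and a segment must score above `min_segment_value`. The toy chunk values are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def find_best_segments_prefix(chunk_values, max_segment_length=20,\n",
    "                              total_max_length=30, min_segment_value=0.2):\n",
    "    \"\"\"Greedy non-overlapping segment search using prefix sums (illustrative sketch).\"\"\"\n",
    "    values = np.asarray(chunk_values, dtype=float)\n",
    "    n = len(values)\n",
    "    # prefix[i] == sum(values[:i]), so sum(values[start:end]) == prefix[end] - prefix[start]\n",
    "    prefix = np.concatenate(([0.0], np.cumsum(values)))\n",
    "    used = np.zeros(n, dtype=bool)  # True for chunks already inside a selected segment\n",
    "    segments, scores = [], []\n",
    "    total = 0\n",
    "\n",
    "    while total < total_max_length:\n",
    "        # Cap the segment length so the total chunk budget is never exceeded\n",
    "        budget = min(max_segment_length, total_max_length - total)\n",
    "        best_score, best_segment = min_segment_value, None\n",
    "        for start in range(n):\n",
    "            if used[start]:\n",
    "                continue\n",
    "            for end in range(start + 1, min(start + budget, n) + 1):\n",
    "                if used[end - 1]:\n",
    "                    break  # every longer segment from this start would overlap too\n",
    "                score = prefix[end] - prefix[start]\n",
    "                if score > best_score:\n",
    "                    best_score, best_segment = score, (start, end)\n",
    "        if best_segment is None:\n",
    "            break  # no remaining segment scores above min_segment_value\n",
    "        segments.append(best_segment)\n",
    "        scores.append(float(best_score))\n",
    "        used[best_segment[0]:best_segment[1]] = True\n",
    "        total += best_segment[1] - best_segment[0]\n",
    "\n",
    "    # Return segments and their scores ordered by start position\n",
    "    order = sorted(range(len(segments)), key=lambda i: segments[i][0])\n",
    "    return [segments[i] for i in order], [scores[i] for i in order]\n",
    "\n",
    "\n",
    "toy_values = [-1.0, 2.0, 3.0, -1.0, -1.0, 2.0, 1.0, -1.0]\n",
    "print(find_best_segments_prefix(toy_values, min_segment_value=0.5))\n",
    "print(find_best_segments_prefix(toy_values, max_segment_length=2, min_segment_value=0.5))\n",
    "print(find_best_segments_prefix(toy_values, total_max_length=3, min_segment_value=0.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the third call, the 6-chunk segment `(1, 7)` would exceed `total_max_length=3`, so the search takes `(1, 3)` and then fills the single remaining slot with `(5, 6)`."
   ]
  },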
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reconstructing and Using Segments for RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reconstruct_segments(chunks, best_segments):\n",
    "    \"\"\"\n",
    "    Reconstruct text segments based on chunk indices.\n",
    "    \n",
    "    Args:\n",
    "        chunks (List[str]): List of all document chunks\n",
    "        best_segments (List[Tuple[int, int]]): List of (start, end) indices for segments\n",
    "        \n",
    "    Returns:\n",
    "        List[Dict]: Reconstructed segments, each with its text and (start, end) chunk range\n",
    "    \"\"\"\n",
    "    reconstructed_segments = []  # Initialize an empty list to store the reconstructed segments\n",
    "    \n",
    "    for start, end in best_segments:\n",
    "        # Concatenate the chunks directly: they are non-overlapping character slices,\n",
    "        # so joining with \"\" restores the original text without inserting spurious spaces\n",
    "        segment_text = \"\".join(chunks[start:end])\n",
    "        # Append the segment text and its range to the reconstructed_segments list\n",
    "        reconstructed_segments.append({\n",
    "            \"text\": segment_text,\n",
    "            \"segment_range\": (start, end),\n",
    "        })\n",
    "    \n",
    "    return reconstructed_segments  # Return the list of reconstructed text segments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_segments_for_context(segments):\n",
    "    \"\"\"\n",
    "    Format segments into a context string for the LLM.\n",
    "    \n",
    "    Args:\n",
    "        segments (List[Dict]): List of segment dictionaries\n",
    "        \n",
    "    Returns:\n",
    "        str: Formatted context text\n",
    "    \"\"\"\n",
    "    context = []  # Initialize an empty list to store the formatted context\n",
    "    \n",
    "    for i, segment in enumerate(segments):\n",
    "        # Create a header for each segment with its index and chunk range\n",
    "        segment_header = f\"SEGMENT {i+1} (Chunks {segment['segment_range'][0]}-{segment['segment_range'][1]-1}):\"\n",
    "        context.append(segment_header)  # Add the segment header to the context list\n",
    "        context.append(segment['text'])  # Add the segment text to the context list\n",
    "        context.append(\"-\" * 80)  # Add a separator line for readability\n",
    "    \n",
    "    # Join all elements in the context list with double newlines and return the result\n",
    "    return \"\\n\\n\".join(context)"
   ]
  },
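  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Segment ranges are half-open, like Python slices, so `format_segments_for_context` subtracts 1 from `end` to print an inclusive range. A segment stored as `(21, 41)` covers chunks 21 through 40 and is labeled `Chunks 21-40`. The hypothetical `segment_header` helper below reproduces just that header line to make the convention explicit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def segment_header(index, segment_range):\n",
    "    # Hypothetical helper mirroring the header line built in format_segments_for_context\n",
    "    start, end = segment_range\n",
    "    # segment_range is half-open, so the last included chunk is end - 1\n",
    "    return f\"SEGMENT {index + 1} (Chunks {start}-{end - 1}):\"\n",
    "\n",
    "\n",
    "print(segment_header(0, (21, 41)))"
   ]
  },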
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating Responses with RSE Context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_response(query, context, model=\"meta-llama/Llama-3.2-3B-Instruct\"):\n",
    "    \"\"\"\n",
    "    Generate a response based on the query and context.\n",
    "    \n",
    "    Args:\n",
    "        query (str): User query\n",
    "        context (str): Context text from relevant segments\n",
    "        model (str): LLM model to use\n",
    "        \n",
    "    Returns:\n",
    "        str: Generated response\n",
    "    \"\"\"\n",
    "    print(\"Generating response using relevant segments as context...\")\n",
    "    \n",
    "    # Define the system prompt to guide the AI's behavior\n",
    "    system_prompt = \"\"\"You are a helpful assistant that answers questions based on the provided context.\n",
    "    The context consists of document segments that have been retrieved as relevant to the user's query.\n",
    "    Use the information from these segments to provide a comprehensive and accurate answer.\n",
    "    If the context doesn't contain relevant information to answer the question, say so clearly.\"\"\"\n",
    "    \n",
    "    # Create the user prompt by combining the context and the query\n",
    "    user_prompt = f\"\"\"\n",
    "Context:\n",
    "{context}\n",
    "\n",
    "Question: {query}\n",
    "\n",
    "Please provide a helpful answer based on the context provided.\n",
    "\"\"\"\n",
    "    \n",
    "    # Generate the response using the specified model\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt}\n",
    "        ],\n",
    "        temperature=0\n",
    "    )\n",
    "    \n",
    "    # Return the generated response content\n",
    "    return response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Complete RSE Pipeline Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rag_with_rse(pdf_path, query, chunk_size=800, irrelevant_chunk_penalty=0.2):\n",
    "    \"\"\"\n",
    "    Complete RAG pipeline with Relevant Segment Extraction.\n",
    "    \n",
    "    Args:\n",
    "        pdf_path (str): Path to the document\n",
    "        query (str): User query\n",
    "        chunk_size (int): Size of chunks\n",
    "        irrelevant_chunk_penalty (float): Penalty for irrelevant chunks\n",
    "        \n",
    "    Returns:\n",
    "        Dict: Result with query, segments, and response\n",
    "    \"\"\"\n",
    "    print(\"\\n=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===\")\n",
    "    print(f\"Query: {query}\")\n",
    "    \n",
    "    # Process the document to extract text, chunk it, and create embeddings\n",
    "    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)\n",
    "    \n",
    "    # Calculate relevance scores and chunk values based on the query\n",
    "    print(\"\\nCalculating relevance scores and chunk values...\")\n",
    "    chunk_values = calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty)\n",
    "    \n",
    "    # Find the best segments of text based on chunk values\n",
    "    best_segments, scores = find_best_segments(\n",
    "        chunk_values, \n",
    "        max_segment_length=20, \n",
    "        total_max_length=30, \n",
    "        min_segment_value=0.2\n",
    "    )\n",
    "    \n",
    "    # Reconstruct text segments from the best chunks\n",
    "    print(\"\\nReconstructing text segments from chunks...\")\n",
    "    segments = reconstruct_segments(chunks, best_segments)\n",
    "    \n",
    "    # Format the segments into a context string for the language model\n",
    "    context = format_segments_for_context(segments)\n",
    "    \n",
    "    # Generate a response from the language model using the context\n",
    "    response = generate_response(query, context)\n",
    "    \n",
    "    # Compile the result into a dictionary\n",
    "    result = {\n",
    "        \"query\": query,\n",
    "        \"segments\": segments,\n",
    "        \"response\": response\n",
    "    }\n",
    "    \n",
    "    print(\"\\n=== FINAL RESPONSE ===\")\n",
    "    print(response)\n",
    "    \n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparing with Standard Retrieval\n",
    "Let's implement a standard retrieval approach to compare with RSE:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def standard_top_k_retrieval(pdf_path, query, k=10, chunk_size=800):\n",
    "    \"\"\"\n",
    "    Standard RAG with top-k retrieval.\n",
    "    \n",
    "    Args:\n",
    "        pdf_path (str): Path to the document\n",
    "        query (str): User query\n",
    "        k (int): Number of chunks to retrieve\n",
    "        chunk_size (int): Size of chunks\n",
    "        \n",
    "    Returns:\n",
    "        Dict: Result with query, chunks, and response\n",
    "    \"\"\"\n",
    "    print(\"\\n=== STARTING STANDARD TOP-K RETRIEVAL ===\")\n",
    "    print(f\"Query: {query}\")\n",
    "    \n",
    "    # Process the document to extract text, chunk it, and create embeddings\n",
    "    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)\n",
    "    \n",
    "    # Create an embedding for the query\n",
    "    print(\"Creating query embedding and retrieving chunks...\")\n",
    "    query_embedding = create_embeddings([query])[0]\n",
    "    \n",
    "    # Retrieve the top-k most relevant chunks based on the query embedding\n",
    "    results = vector_store.search(query_embedding, top_k=k)\n",
    "    retrieved_chunks = [result[\"document\"] for result in results]\n",
    "    \n",
    "    # Format the retrieved chunks into a context string\n",
    "    context = \"\\n\\n\".join([\n",
    "        f\"CHUNK {i+1}:\\n{chunk}\" \n",
    "        for i, chunk in enumerate(retrieved_chunks)\n",
    "    ])\n",
    "    \n",
    "    # Generate a response from the language model using the context\n",
    "    response = generate_response(query, context)\n",
    "    \n",
    "    # Compile the result into a dictionary\n",
    "    result = {\n",
    "        \"query\": query,\n",
    "        \"chunks\": retrieved_chunks,\n",
    "        \"response\": response\n",
    "    }\n",
    "    \n",
    "    print(\"\\n=== FINAL RESPONSE ===\")\n",
    "    print(response)\n",
    "    \n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation of RSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_methods(pdf_path, query, reference_answer=None):\n",
    "    \"\"\"\n",
    "    Compare RSE with standard top-k retrieval.\n",
    "    \n",
    "    Args:\n",
    "        pdf_path (str): Path to the document\n",
    "        query (str): User query\n",
    "        reference_answer (str, optional): Reference answer for evaluation\n",
    "    \"\"\"\n",
    "    print(\"\\n========= EVALUATION =========\\n\")\n",
    "    \n",
    "    # Run the RAG with Relevant Segment Extraction (RSE) method\n",
    "    rse_result = rag_with_rse(pdf_path, query)\n",
    "    \n",
    "    # Run the standard top-k retrieval method\n",
    "    standard_result = standard_top_k_retrieval(pdf_path, query)\n",
    "    \n",
    "    # If a reference answer is provided, evaluate the responses\n",
    "    if reference_answer:\n",
    "        print(\"\\n=== COMPARING RESULTS ===\")\n",
    "        \n",
    "        # Create an evaluation prompt to compare the responses against the reference answer\n",
    "        evaluation_prompt = f\"\"\"\n",
    "            Query: {query}\n",
    "\n",
    "            Reference Answer:\n",
    "            {reference_answer}\n",
    "\n",
    "            Response from Standard Retrieval:\n",
    "            {standard_result[\"response\"]}\n",
    "\n",
    "            Response from Relevant Segment Extraction:\n",
    "            {rse_result[\"response\"]}\n",
    "\n",
    "            Compare these two responses against the reference answer. Which one is:\n",
    "            1. More accurate and comprehensive\n",
    "            2. Better at addressing the user's query\n",
    "            3. Less likely to include irrelevant information\n",
    "\n",
    "            Explain your reasoning for each point.\n",
    "        \"\"\"\n",
    "        \n",
    "        print(\"Evaluating responses against reference answer...\")\n",
    "        \n",
    "        # Generate the evaluation using the specified model\n",
    "        evaluation = client.chat.completions.create(\n",
    "            model=\"meta-llama/Llama-3.2-3B-Instruct\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": \"You are an objective evaluator of RAG system responses.\"},\n",
    "                {\"role\": \"user\", \"content\": evaluation_prompt}\n",
    "            ]\n",
    "        )\n",
    "        \n",
    "        # Print the evaluation results\n",
    "        print(\"\\n=== EVALUATION RESULTS ===\")\n",
    "        print(evaluation.choices[0].message.content)\n",
    "    \n",
    "    # Return the results of both methods\n",
    "    return {\n",
    "        \"rse_result\": rse_result,\n",
    "        \"standard_result\": standard_result\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "========= EVALUATION =========\n",
      "\n",
      "\n",
      "=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===\n",
      "Query: What is 'Explainable AI' and why is it considered important?\n",
      "Extracting text from document...\n",
      "Chunking text into non-overlapping segments...\n",
      "Created 42 chunks\n",
      "Generating embeddings for chunks...\n",
      "\n",
      "Calculating relevance scores and chunk values...\n",
      "Finding optimal continuous text segments...\n",
      "Found segment (21, 41) with score 9.0718\n",
      "Found segment (0, 20) with score 8.8685\n",
      "\n",
      "Reconstructing text segments from chunks...\n",
      "Generating response using relevant segments as context...\n",
      "\n",
      "=== FINAL RESPONSE ===\n",
      "Based on the context provided, Explainable AI (XAI) refers to the development of techniques that make AI systems more transparent and understandable. The goal of XAI is to provide insights into how AI models make decisions, enhancing trust and accountability in AI systems.\n",
      "\n",
      "XAI is considered important for several reasons:\n",
      "\n",
      "1. **Building trust**: XAI helps users understand how AI systems arrive at their decisions, which is essential for building trust in AI. When users can see how AI systems work, they are more likely to accept the results.\n",
      "2. **Addressing bias**: XAI can help identify biases in AI systems by providing insights into how they make decisions. By understanding how AI systems work, developers can identify and address biases in the data they are trained on.\n",
      "3. **Improving accountability**: XAI enables developers to take responsibility for the decisions made by AI systems. By providing explanations for AI decisions, developers can be held accountable for any errors or biases in the system.\n",
      "4. **Enhancing transparency**: XAI provides insights into how AI systems work, which is essential for transparency in AI decision-making. This is particularly important in high-stakes applications, such as healthcare or finance, where users need to understand how AI systems arrive at their decisions.\n",
      "\n",
      "Overall, XAI is considered important because it addresses the need for transparency, accountability, and trust in AI systems. By providing insights into how AI models make decisions, XAI can help build trust, address bias, and improve accountability in AI development and deployment.\n",
      "\n",
      "=== STARTING STANDARD TOP-K RETRIEVAL ===\n",
      "Query: What is 'Explainable AI' and why is it considered important?\n",
      "Extracting text from document...\n",
      "Chunking text into non-overlapping segments...\n",
      "Created 42 chunks\n",
      "Generating embeddings for chunks...\n",
      "Creating query embedding and retrieving chunks...\n",
      "Generating response using relevant segments as context...\n",
      "\n",
      "=== FINAL RESPONSE ===\n",
      "Based on the provided context, Explainable AI (XAI) is a technique that aims to make AI decisions more understandable, enabling users to assess their fairness and accuracy. XAI techniques are designed to provide insights into how AI systems arrive at their decisions, enhancing transparency and explainability.\n",
      "\n",
      "XAI is considered important for several reasons:\n",
      "\n",
      "1. **Building trust in AI**: By making AI decisions more understandable, XAI helps build trust in AI systems, which is essential for their widespread adoption.\n",
      "2. **Addressing potential harms**: XAI can help identify potential biases and errors in AI decision-making, allowing for more effective mitigation and prevention of harms.\n",
      "3. **Ensuring accountability**: XAI provides a way to establish accountability for AI decisions, which is crucial for addressing potential consequences and ensuring ethical behavior.\n",
      "4. **Improving fairness and accuracy**: By providing insights into AI decision-making, XAI can help identify and address biases and errors, leading to more fair and accurate outcomes.\n",
      "\n",
      "Overall, Explainable AI is a critical aspect of developing trustworthy, fair, and accurate AI systems, and its importance will only continue to grow as AI becomes increasingly pervasive in various domains.\n",
      "\n",
      "=== COMPARING RESULTS ===\n",
      "Evaluating responses against reference answer...\n",
      "\n",
      "=== EVALUATION RESULTS ===\n",
      "Based on the comparison, I would conclude that:\n",
      "\n",
      "1. **The Response from Standard Retrieval is more accurate and comprehensive:**\n",
      "   The Response from Standard Retrieval provides a clear definition of Explainable AI (XAI) and its importance. It explains the goals of XAI, its key aspects (transparency and understandability), and its benefits. The explanation highlights the reasons why XAI is considered important, which includes building trust, addressing potential harms, ensuring accountability, and improving fairness and accuracy.\n",
      "\n",
      "   On the other hand, the Response from Relevant Segment Extraction provides a simplified explanation of XAI but focuses more on the aspects of trust, bias, accountability, and transparency. While it does provide a clear overview of XAI, it is more concise and somewhat less detailed than the Response from Standard Retrieval.\n",
      "\n",
      "2. **The Response from Relevant Segment Extraction is better at addressing the user's query:**\n",
      "   The Response from Standard Retrieval not only answers the question but also provides additional context and importance. The Response from Relevant Segment Extraction, however, more closely addresses the original question by focusing on the core aspects of XAI and its advantages.\n",
      "\n",
      "3. **The Response from Standard Retrieval is less likely to include irrelevant information:**\n",
      "   This response is less likely to contain unnecessary details. The Response from Standard Retrieval provides a clear and concise answer with a clear structure, focusing on the key points of XAI and its importance. In contrast, the Response from Relevant Segment Extraction might include some details about AI systems, updates, or other related information that is not essential to answering the original question.\n",
      "\n",
      "However, the Response from Standard Retrieval includes a more relevant and comprehensive discussion of the reference answer's points, than the Response from Relevant Segment Extraction, which is closer to the reference answer but strays closer to an ensuing span-specific update concerning explainable AI underpoint of the start precisuliar generation machinery.\n"
     ]
    }
   ],
   "source": [
    "# Load the validation data from a JSON file\n",
    "with open('data/val.json') as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# Extract the first query from the validation data\n",
    "query = data[0]['question']\n",
    "\n",
    "# Extract the reference answer from the validation data\n",
    "reference_answer = data[0]['ideal_answer']\n",
    "\n",
    "# Path to the source PDF document used as the knowledge base\n",
    "pdf_path = \"data/AI_Information.pdf\"\n",
    "\n",
    "# Run RSE and standard top-k retrieval, then compare both responses against the reference answer\n",
    "results = evaluate_methods(pdf_path, query, reference_answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv-new-specific-rag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
